{
 "cells": [
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_label: Google Cloud Vertex AI\n",
    "keywords: [gemini, vertex, ChatVertexAI, gemini-pro]\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ChatVertexAI\n",
    "\n",
    "Note: This is separate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this supports the models made available through there. \n",
    "\n",
    "ChatVertexAI exposes all foundational models available in Google Cloud:\n",
    "\n",
    "- Gemini (`gemini-pro` and `gemini-pro-vision`)\n",
    "- PaLM 2 for Text (`text-bison`)\n",
    "- Codey for Code Generation (`codechat-bison`)\n",
    "\n",
    "For a full and updated list of available models visit [VertexAI documentation](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/overview).\n",
    "\n",
    "By default, Google Cloud [does not use](https://cloud.google.com/vertex-ai/docs/generative-ai/data-governance#foundation_model_development) customer data to train its foundation models as part of Google Cloud`s AI/ML Privacy Commitment. More details about how Google processes data can also be found in [Google's Customer Data Processing Addendum (CDPA)](https://cloud.google.com/terms/data-processing-addendum).\n",
    "\n",
    "To use `Google Cloud Vertex AI` PaLM you must have the `langchain-google-vertexai` Python package installed and either:\n",
    "- Have credentials configured for your environment (gcloud, workload identity, etc...)\n",
    "- Store the path to a service account JSON file as the GOOGLE_APPLICATION_CREDENTIALS environment variable\n",
    "\n",
    "This codebase uses the `google.auth` library which first looks for the application credentials variable mentioned above, and then looks for system-level auth.\n",
    "\n",
    "For more information, see: \n",
    "- https://cloud.google.com/docs/authentication/application-default-credentials#GAC\n",
    "- https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#module-google.auth\n",
    "\n"
   ]
  },
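  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the lookup order described above, the cell below reports which credential source would be tried first. The helper `describe_credential_source` is hypothetical and is not part of `google.auth` or LangChain; it only inspects the environment and does not authenticate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def describe_credential_source(env=None):\n",
    "    # Hypothetical helper (an assumption for illustration, not a library API).\n",
    "    env = os.environ if env is None else env\n",
    "    path = env.get(\"GOOGLE_APPLICATION_CREDENTIALS\")\n",
    "    if path:\n",
    "        # The environment variable is consulted first when it is set.\n",
    "        status = \"found\" if Path(path).is_file() else \"missing\"\n",
    "        return f\"service account file ({status}): {path}\"\n",
    "    # Otherwise the lookup falls back to system-level auth.\n",
    "    return \"system-level credentials (gcloud, workload identity, etc.)\"\n",
    "\n",
    "\n",
    "print(describe_credential_source())"
   ]
  },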
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet  langchain-google-vertexai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_google_vertexai import ChatVertexAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\" J'aime la programmation.\")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "system = \"You are a helpful assistant who translate English to French\"\n",
    "human = \"Translate this sentence from English to French. I love programming.\"\n",
    "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
    "\n",
    "chat = ChatVertexAI()\n",
    "\n",
    "chain = prompt | chat\n",
    "chain.invoke({})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Gemini doesn't support SystemMessage at the moment, but it can be added to the first human message in the row. If you want such behavior, just set the `convert_system_message_to_human` to `True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"J'aime la programmation.\")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "system = \"You are a helpful assistant who translate English to French\"\n",
    "human = \"Translate this sentence from English to French. I love programming.\"\n",
    "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
    "\n",
    "chat = ChatVertexAI(model_name=\"gemini-pro\", convert_system_message_to_human=True)\n",
    "\n",
    "chain = prompt | chat\n",
    "chain.invoke({})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we want to construct a simple chain that takes user specified parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=' プログラミングが大好きです')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "system = (\n",
    "    \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
    ")\n",
    "human = \"{text}\"\n",
    "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
    "\n",
    "chat = ChatVertexAI()\n",
    "\n",
    "chain = prompt | chat\n",
    "\n",
    "chain.invoke(\n",
    "    {\n",
    "        \"input_language\": \"English\",\n",
    "        \"output_language\": \"Japanese\",\n",
    "        \"text\": \"I love programming\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Code generation chat models\n",
    "You can now leverage the Codey API for code chat within Vertex AI. The model available is:\n",
    "- `codechat-bison`: for code assistance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " ```python\n",
      "def is_prime(n):\n",
      "  \"\"\"\n",
      "  Check if a number is prime.\n",
      "\n",
      "  Args:\n",
      "    n: The number to check.\n",
      "\n",
      "  Returns:\n",
      "    True if n is prime, False otherwise.\n",
      "  \"\"\"\n",
      "\n",
      "  # If n is 1, it is not prime.\n",
      "  if n == 1:\n",
      "    return False\n",
      "\n",
      "  # Iterate over all numbers from 2 to the square root of n.\n",
      "  for i in range(2, int(n ** 0.5) + 1):\n",
      "    # If n is divisible by any number from 2 to its square root, it is not prime.\n",
      "    if n % i == 0:\n",
      "      return False\n",
      "\n",
      "  # If n is divisible by no number from 2 to its square root, it is prime.\n",
      "  return True\n",
      "\n",
      "\n",
      "def find_prime_numbers(n):\n",
      "  \"\"\"\n",
      "  Find all prime numbers up to a given number.\n",
      "\n",
      "  Args:\n",
      "    n: The upper bound for the prime numbers to find.\n",
      "\n",
      "  Returns:\n",
      "    A list of all prime numbers up to n.\n",
      "  \"\"\"\n",
      "\n",
      "  # Create a list of all numbers from 2 to n.\n",
      "  numbers = list(range(2, n + 1))\n",
      "\n",
      "  # Iterate over the list of numbers and remove any that are not prime.\n",
      "  for number in numbers:\n",
      "    if not is_prime(number):\n",
      "      numbers.remove(number)\n",
      "\n",
      "  # Return the list of prime numbers.\n",
      "  return numbers\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "chat = ChatVertexAI(\n",
    "    model_name=\"codechat-bison\", max_output_tokens=1000, temperature=0.5\n",
    ")\n",
    "\n",
    "message = chat.invoke(\"Write a Python function generating all prime numbers\")\n",
    "print(message.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Full generation info\n",
    "\n",
    "We can use the `generate` method to get back extra metadata like [safety attributes](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_confidence_scoring) and not just chat completions\n",
    "\n",
    "Note that the `generation_info` will be different depending if you're using a gemini model or not.\n",
    "\n",
    "### Gemini model\n",
    "\n",
    "`generation_info` will include:\n",
    "\n",
    "- `is_blocked`: whether generation was blocked or not\n",
    "- `safety_ratings`: safety ratings' categories and probability labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'is_blocked': False,\n",
      " 'safety_ratings': [{'category': 'HARM_CATEGORY_HARASSMENT',\n",
      "                     'probability_label': 'NEGLIGIBLE'},\n",
      "                    {'category': 'HARM_CATEGORY_HATE_SPEECH',\n",
      "                     'probability_label': 'NEGLIGIBLE'},\n",
      "                    {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT',\n",
      "                     'probability_label': 'NEGLIGIBLE'},\n",
      "                    {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT',\n",
      "                     'probability_label': 'NEGLIGIBLE'}]}\n"
     ]
    }
   ],
   "source": [
    "from pprint import pprint\n",
    "\n",
    "from langchain_core.messages import HumanMessage\n",
    "from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory\n",
    "\n",
    "human = \"Translate this sentence from English to French. I love programming.\"\n",
    "messages = [HumanMessage(content=human)]\n",
    "\n",
    "\n",
    "chat = ChatVertexAI(\n",
    "    model_name=\"gemini-pro\",\n",
    "    safety_settings={\n",
    "        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE\n",
    "    },\n",
    ")\n",
    "\n",
    "result = chat.generate([messages])\n",
    "pprint(result.generations[0][0].generation_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Non-gemini model\n",
    "\n",
    "`generation_info` will include:\n",
    "\n",
    "- `is_blocked`: whether generation was blocked or not\n",
    "- `safety_attributes`: a dictionary mapping safety attributes to their scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'is_blocked': False,\n",
      " 'safety_attributes': {'Derogatory': 0.1,\n",
      "                       'Finance': 0.3,\n",
      "                       'Insult': 0.1,\n",
      "                       'Sexual': 0.1}}\n"
     ]
    }
   ],
   "source": [
    "chat = ChatVertexAI()  # default is `chat-bison`\n",
    "\n",
    "result = chat.generate([messages])\n",
    "pprint(result.generations[0][0].generation_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Function Calling with Gemini\n",
    "\n",
    "We can call Gemini models with tools."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MyModel(name='Erick', age=27)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain.pydantic_v1 import BaseModel\n",
    "from langchain_google_vertexai import create_structured_runnable\n",
    "\n",
    "llm = ChatVertexAI(model_name=\"gemini-pro\")\n",
    "\n",
    "\n",
    "class MyModel(BaseModel):\n",
    "    name: str\n",
    "    age: int\n",
    "\n",
    "\n",
    "chain = create_structured_runnable(MyModel, llm)\n",
    "chain.invoke(\"My name is Erick and I'm 27 years old\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Asynchronous calls\n",
    "\n",
    "We can make asynchronous calls via the Runnables [Async Interface](/docs/expression_language/interface)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for running these examples in the notebook:\n",
    "import asyncio\n",
    "\n",
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=' अहं प्रोग्रामनं प्रेमामि')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "system = (\n",
    "    \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
    ")\n",
    "human = \"{text}\"\n",
    "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
    "\n",
    "chat = ChatVertexAI(model_name=\"chat-bison\", max_output_tokens=1000, temperature=0.5)\n",
    "chain = prompt | chat\n",
    "\n",
    "asyncio.run(\n",
    "    chain.ainvoke(\n",
    "        {\n",
    "            \"input_language\": \"English\",\n",
    "            \"output_language\": \"Sanskrit\",\n",
    "            \"text\": \"I love programming\",\n",
    "        }\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Streaming calls\n",
    "\n",
    "We can also stream outputs via the `stream` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " The five most populous countries in the world are:\n",
      "1. China (1.4 billion)\n",
      "2. India (1.3 billion)\n",
      "3. United States (331 million)\n",
      "4. Indonesia (273 million)\n",
      "5. Pakistan (220 million)"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [(\"human\", \"List out the 5 most populous countries in the world\")]\n",
    ")\n",
    "\n",
    "chat = ChatVertexAI()\n",
    "\n",
    "chain = prompt | chat\n",
    "\n",
    "for chunk in chain.stream({}):\n",
    "    sys.stdout.write(chunk.content)\n",
    "    sys.stdout.flush()"
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "common-cpu.m108",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cpu:m108"
  },
  "kernelspec": {
   "display_name": "",
   "name": ""
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
